#!/usr/bin/env python3
import torch

IGNORE_ID = -1


def pad_list(xs, pad_value):
    """Pad a list of variable-length tensors into one padded batch tensor.

    From: espnet/src/nets/e2e_asr_th.py: pad_list()
    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad


# -- Transformer Related --
def get_non_pad_mask(padded_input, input_lengths=None, pad_idx=None):
    """Padding positions are set to 0; use either input_lengths or pad_idx."""
    assert input_lengths is not None or pad_idx is not None
    if input_lengths is not None:
        # padded_input: N x T x ..
        N = padded_input.size(0)
        non_pad_mask = padded_input.new_ones(padded_input.size()[:-1])  # N x T
        for i in range(N):
            non_pad_mask[i, input_lengths[i]:] = 0
    if pad_idx is not None:
        # padded_input: N x T
        assert padded_input.dim() == 2
        non_pad_mask = padded_input.ne(pad_idx).float()
    # unsqueeze(-1) for broadcasting over the feature dimension
    return non_pad_mask.unsqueeze(-1)


def get_subsequent_mask(seq):
    """For masking out the subsequent (future) positions."""
    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8),
        diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls
    return subsequent_mask


def get_attn_key_pad_mask(seq_k, seq_q, pad_idx):
    """For masking out the padding part of the key sequence."""
    # Expand to fit the shape of the key-query attention matrix.
    len_q = seq_q.size(1)
    padding_mask = seq_k.eq(pad_idx)
    padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1)  # b x lq x lk
    return padding_mask


def get_attn_pad_mask(padded_input, input_lengths, expand_length):
    """Mask positions are set to 1."""
    # N x Ti x 1
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
    # N x Ti; lt(1) acts like a logical NOT on the 0/1 mask
    pad_mask = non_pad_mask.squeeze(-1).lt(1)
    attn_mask = pad_mask.unsqueeze(1).expand(-1, expand_length, -1)
    return attn_mask
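

# -- Usage sketch (illustrative only, not part of the original module) --
# A minimal demo of the mask helpers above, assuming toy shapes and a pad id
# of 0 for token sequences; all values here are made up for illustration.
if __name__ == "__main__":
    batch, max_time, feat_dim = 2, 5, 3
    input_lengths = [5, 3]

    # Padded acoustic features: N x T x D
    padded_input = torch.randn(batch, max_time, feat_dim)

    # N x T x 1 mask: 1 for real frames, 0 for padding
    non_pad_mask = get_non_pad_mask(padded_input, input_lengths=input_lengths)
    print(non_pad_mask.squeeze(-1))

    # N x T x T self-attention mask: 1 (True) marks padded key positions
    attn_mask = get_attn_pad_mask(padded_input, input_lengths, max_time)
    print(attn_mask.shape)

    # Decoder-side masks over token ids, assuming pad id 0
    ys = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 0, 0, 0]])
    print(get_subsequent_mask(ys).shape)                    # b x ls x ls upper-triangular mask
    print(get_attn_key_pad_mask(ys, ys, pad_idx=0).shape)   # b x lq x lk padding mask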